# -*- coding: utf-8 -*-
"""
Created on Tue May  6 21:30:38 2025

@author: bramv
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import truncnorm
from tqdm import tqdm
import openpyxl

# PARAMETERS
# Forecast horizon: one entry per calendar year from 2025 through 2050
# inclusive (np.arange excludes the stop value, hence 2051).
years = np.arange(2025, 2051)
# Number of Monte Carlo runs performed by the simulation loop further down.
n_sim = 10000

def make_truncnorm(mean, std, lower, upper):
    """Return a frozen normal distribution truncated to ``[lower, upper]``.

    ``scipy.stats.truncnorm`` takes its clip points in units of standard
    deviations relative to the mean, so the absolute bounds are standardised
    before the distribution is frozen.

    Parameters
    ----------
    mean, std
        Location and scale of the underlying (untruncated) normal.
    lower, upper
        Truncation bounds, expressed on the same scale as ``mean``.

    Returns
    -------
    A frozen ``truncnorm`` distribution (supports ``rvs``, ``pdf``, ...).

    Raises
    ------
    ValueError
        If ``std`` is not strictly positive or ``lower`` is not strictly
        below ``upper``. Without this check a zero ``std`` surfaces as a
        ``ZeroDivisionError`` and the other cases yield a distribution that
        only misbehaves later, when it is sampled.
    """
    # ``not all(x > 0)`` (rather than ``any(x <= 0)``) also rejects NaN.
    if not np.all(np.asarray(std) > 0):
        raise ValueError(f"std must be strictly positive, got {std!r}")
    if not np.all(np.asarray(lower) < np.asarray(upper)):
        raise ValueError(
            f"lower must be strictly below upper, got lower={lower!r}, upper={upper!r}"
        )

    a = (lower - mean) / std
    b = (upper - mean) / std
    return truncnorm(a, b, loc=mean, scale=std)

# 3 UNCERTAINTIES (3.1 Deployment AEC & PEM of the report)

# UNCERTAINTY 1: INITIAL ELECTROLYSER CAPACITY 2025 (TRUNCATED NORMAL DISTRIBUTION)
# Normal(mean=4.30, std=4.35) restricted to [0.66, 17.70]; plotted below on a
# GW axis.
C0_dist = make_truncnorm(mean=4.30, std=4.35, lower=0.66, upper=17.70)

# Sample from the distribution. These draws only feed the histogram; the sample
# size follows n_sim so the figure reflects the same number of draws as the
# Monte Carlo run instead of a separately hard-coded count.
samples_C0 = C0_dist.rvs(size=n_sim)

# Plot the histogram (Figure 4.1)
plt.figure(figsize=(8, 5))
plt.hist(samples_C0, bins=50, density=True, color='lightblue', edgecolor='black')
plt.title("Initial Electrolyzer Capacity Distribution AEC (C₀, 2025)")
plt.xlabel("Capacity [GW]", fontsize=16)
plt.ylabel("Probability Density", fontsize=16)
plt.tick_params(axis='both', labelsize=16)
plt.grid(True)
plt.tight_layout()
plt.show()

# UNCERTAINTY 2: GROWTH RATE %/YEAR (TRUNCATED NORMAL DISTRIBUTION)
# Normal(mean=39, std=13.75) restricted to [15, 70], expressed in percent per
# year; the simulation loop divides each draw by 100 before using it.
b_dist = make_truncnorm(mean=39, std=13.75, lower=15, upper=70)

# Sample from the distribution. These draws only feed the histogram; the sample
# size follows n_sim so the figure reflects the same number of draws as the
# Monte Carlo run instead of a separately hard-coded count.
samples_b = b_dist.rvs(size=n_sim)

# Plot the histogram (Figure 4.2)
plt.figure(figsize=(8, 5))
plt.hist(samples_b, bins=50, density=True, color='mediumorchid', edgecolor='black')
plt.title("Growth Rate Distribution (b) [%/year]")
plt.xlabel("Growth Rate [%/year]")
plt.ylabel("Probability Density")
plt.grid(True)
plt.tight_layout()
plt.show()

# UNCERTAINTY 3: DEMAND TARGETS 2030 AND 2050 (ANTICIPATION LEVEL 0, 5, 10) 
# Number of years subtracted from every milestone year before the demand-pull
# ceiling is interpolated (see the Monte Carlo loop), i.e. how far ahead of
# the target year the ceiling is reached.
anticipation = 5

# Milestone year -> target value used as the ceiling of the logistic model.
# NOTE(review): presumably in GW, matching the capacity axis of the forecast
# plots - confirm against the report.
demand_targets = {
    2030: 398,    # NZE target AEC 2030 (3.1.2. Reference Case Deployment)
    2050: 1137    # NZE target AEC 2050 (3.1.2. Reference Case Deployment)
}


#%%

# Interpolate demand pull trajectory
def build_Cmax(years, milestones):
    """Interpolate the demand-pull ceiling for every forecast year.

    The ceiling is piecewise linear between consecutive milestones (taken in
    ascending year order). Years before the first milestone take the first
    milestone's value and years after the last milestone take the last one's.

    Parameters
    ----------
    years
        Years to evaluate; a NumPy array or any plain sequence of numbers.
    milestones
        Mapping of milestone year -> ceiling value. Insertion order does not
        matter.

    Returns
    -------
    numpy.ndarray
        Float array with one ceiling value per entry of ``years``.

    Raises
    ------
    ValueError
        If ``milestones`` is empty.
    """
    if not milestones:
        raise ValueError("milestones must contain at least one year")

    sorted_years = sorted(milestones)
    ceilings = [milestones[year] for year in sorted_years]

    # np.interp is piecewise linear between the sample points and holds the
    # first/last value outside them, which is exactly the rule above. Unlike a
    # segment-by-segment loop it also covers the single-milestone case, where
    # there is no segment and the milestone year itself would be left at 0.
    return np.interp(np.asarray(years, dtype=float), sorted_years, ceilings)

# Logistic forecast model
def logistic_forecast(C0, b, Cmax_series):
    """Step a discrete logistic growth model through a time-varying ceiling.

    Starting from ``C0``, each following step adds ``b * C * (1 - C / Cmax)``
    to the current value ``C``, where ``Cmax`` is the ceiling of the step being
    computed. The first entry of ``Cmax_series`` only fixes the position of the
    starting point; its value is never read.

    Parameters
    ----------
    C0
        Value at the first time step.
    b
        Growth rate per step, as a fraction (not a percentage).
    Cmax_series
        Ceiling for every time step.

    Returns
    -------
    numpy.ndarray
        The trajectory, beginning with ``C0``.
    """
    trajectory = [C0]
    current = C0
    # Skip the first ceiling: it belongs to the step already occupied by C0.
    for ceiling in Cmax_series[1:]:
        growth = b * current * (1 - current / ceiling)
        current = current + growth
        trajectory.append(current)
    return np.array(trajectory)



#%% Monte Carlo simulation (3.3. Monte Carlo Simulation)
# Every run pairs one initial capacity C0 with one growth rate b, builds the
# demand-pull ceiling from the anticipation-shifted milestones and steps the
# logistic model through `years`.
sim_results = []

# Draw all random inputs in two vectorised calls: calling rvs() on a frozen
# scipy distribution once per iteration is far slower and statistically
# equivalent.
C0_samples = C0_dist.rvs(size=n_sim)
b_samples = b_dist.rvs(size=n_sim) / 100  # %/year -> fraction per year

for C0, b in tqdm(zip(C0_samples, b_samples), total=n_sim):
    targets = {
        2025: C0,
        2030: demand_targets[2030],
        2050: demand_targets[2050],
    }

    # Anticipation moves every milestone `anticipation` years earlier.
    shifted_targets = {year - anticipation: val for year, val in targets.items()}

    Cmax_series = build_Cmax(years, shifted_targets)

    # Skip runs whose first-year ceiling is not positive.
    if Cmax_series[0] <= 0:
        continue

    # The starting capacity may not exceed the first-year ceiling.
    C0 = min(C0, Cmax_series[0])
    try:
        C_t = logistic_forecast(C0, b, Cmax_series)
    except ArithmeticError:
        # Best effort: drop runs that fail numerically (division by zero,
        # overflow), but let programming errors and Ctrl-C propagate instead
        # of hiding them behind a bare `except`.
        continue
    sim_results.append(C_t)

# Fail here with a clear message rather than later inside the percentile and
# plotting code.
if not sim_results:
    raise RuntimeError("Monte Carlo simulation produced no valid trajectories")

sim_results = np.array(sim_results)
#%% Plotting
# Per-year 5th, 50th and 95th percentiles across all simulated trajectories.
p5, p50, p95 = np.percentile(sim_results, [5, 50, 95], axis=0)

# Reference demand-pull curve. The 2025 anchor uses the median of the initial
# capacity distribution, so the curve is reproducible and does not depend on
# whichever random draw the simulation loop happened to finish on.
median_targets = {
    2025: C0_dist.median(),
    2030: demand_targets[2030],
    2050: demand_targets[2050],
}
shifted_median_targets = {year - anticipation: val for year, val in median_targets.items()}
Cmax_t_median = build_Cmax(years, shifted_median_targets)

plt.figure(figsize=(10, 6))
plt.fill_between(years, p5, p95, alpha=0.3, label='90% range')
plt.plot(years, p50, label='Median forecast', color='blue')
plt.plot(years, Cmax_t_median, '--', label='Demand pull', color='black')
plt.xlabel('Year')
plt.ylabel('Electrolyzer Capacity [GW]')
plt.title('Logistic Forecast with Uncertain Demand Pull (2030 and 2050 Targets Only)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

#%% Plot: 50% and 90% intervals (Figure 4.3 Probabilistic S-curve)
plt.figure(figsize=(10, 6))

# Calculate 25th and 75th percentiles for 50% interval
p25, p75 = np.percentile(sim_results, [25, 75], axis=0)

# Fill intervals. The band between the 5th and 95th percentiles contains the
# central 90% of the simulated trajectories, so it is labelled as such (a true
# 95% band would require the 2.5th and 97.5th percentiles).
plt.fill_between(years, p5, p95, alpha=0.3, label='90% interval', color='skyblue')
plt.fill_between(years, p25, p75, alpha=0.5, label='50% interval', color='deepskyblue')

# Plot median and demand pull
plt.plot(years, p50, label='Median forecast', color='blue')
plt.plot(years, Cmax_t_median, '--', label='Demand pull', color='black')

# Axis labels and title
plt.xlabel('Year',fontsize=16)
plt.ylabel('Electrolyzer Capacity [GW]',fontsize=16)
plt.title('Logistic Forecast AEC: 50% and 90% Confidence Intervals',fontsize=14)
plt.legend(fontsize=16)
plt.grid(True)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.tight_layout()
plt.show()

#%% Plot: Individual Simulation Trajectories
plt.figure(figsize=(10, 6))

# Limit to the first 500 simulations for readability. Fewer are drawn when the
# run produced fewer trajectories, and the title reports the actual count.
max_shown = 500
shown_sims = sim_results[:max_shown]
for sim in shown_sims:
    plt.plot(years, sim, color='lightgray', alpha=0.15)

# Plot median and demand pull
plt.plot(years, p50, color='blue', label='Median forecast')
plt.plot(years, Cmax_t_median, '--', color='black', label='Demand pull')

# Axis labels and title
plt.xlabel('Year')
plt.ylabel('Electrolyzer Capacity [GW]')
plt.title(f'Individual Simulation Trajectories ({len(shown_sims)} samples)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

#%% Load the existing Excel workbook
file_path = "DEPLOYMENT FORECAST AEC.xlsx"
wb = openpyxl.load_workbook(file_path)
ws = wb.active  # whichever sheet the workbook marks as active

# First worksheet row that receives forecast data (Excel rows are 1-indexed)
start_row = 27

# Deployment trajectory scenario to export: one of the percentile curves
percentile = p50
for i, (year, forecast) in enumerate(zip(years, percentile), start=start_row):
    # Columns 1..3: year, forecast rounded to two decimals (stored as a
    # number, so Excel applies its own decimal separator), and a row tag.
    row_values = (year, round(forecast, 2), "forecast")
    for column, value in enumerate(row_values, start=1):
        ws.cell(row=i, column=column, value=value)

# Save the changes back into the same workbook file
wb.save(file_path)
print("Data successfully written to Excel.")

#%%
#%% Print key forecast statistics for 2030, 2040 and 2050
#years_of_interest = [2030, 2040, 2050]
#for year in years_of_interest:
#    idx = np.where(years == year)[0][0]
#    print(f"\nForecast statistics for {year}:")
#    print(f"  - Median (p50):     {p50[idx]:.2f} GW")
#    print(f"  - 50% interval:     {p25[idx]:.2f} – {p75[idx]:.2f} GW")
#    print(f"  - 90% interval:     {p5[idx]:.2f} – {p95[idx]:.2f} GW")


#%%
#%% Print full yearly forecast statistics (90% and 50% intervals + median)
#print("\nFull Yearly Forecast Statistics:")
#print("Year\tp5 [GW]\tp25 [GW]\tp50 [GW]\tp75 [GW]\tp95 [GW]")
#for i, year in enumerate(years):
#    print(f"{year}\t{p5[i]:.2f}\t{p25[i]:.2f}\t{p50[i]:.2f}\t{p75[i]:.2f}\t{p95[i]:.2f}")
